{
 "cells": [
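  {
   "cell_type": "markdown",
   "id": "08fe2f4a",
   "metadata": {},
   "source": [
    "# Reflexion Development Guide\n",
    "\n",
    "This guide shows how to use LangGraph to build a self-reflective agent based on the Reflexion architecture by Shinn et al. The agent critiques its own responses to a task in order to produce a higher-quality final answer, at the cost of longer execution time.\n",
    "\n",
    "## Workflow\n",
    "\n",
    "1. Installation and setup\n",
    "2. Set up LangSmith for LangGraph development\n",
    "3. Define our LLM\n",
    "4. Actor (with reflection)\n",
    "5. Construct tools\n",
    "6. Initial responder\n",
    "7. Revision\n",
    "8. Create the tool node\n",
    "9. Construct the graph\n",
    "\n",
    "## 1. Installation and setup\n",
    "\n",
    "First, install `langgraph` (the framework), `langchain_anthropic` (the LLM), and `tavily-python` (the search engine)."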
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "846f6d49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# langchain_community provides the Tavily tool wrappers imported in section 5\n",
    "%pip install -U --quiet langgraph langchain_anthropic langchain_community\n",
    "%pip install -U --quiet tavily-python\n",
    "import getpass\n",
    "import os\n",
    "\n",
    "\n",
    "def _set_if_undefined(var: str) -> None:\n",
    "    \"\"\"Prompt for an environment variable only if it is not already set.\"\"\"\n",
    "    if os.environ.get(var):\n",
    "        return\n",
    "    os.environ[var] = getpass.getpass(var)\n",
    "\n",
    "\n",
    "_set_if_undefined(\"ANTHROPIC_API_KEY\")\n",
    "_set_if_undefined(\"TAVILY_API_KEY\")"
   ]
  },
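  {
   "cell_type": "markdown",
   "id": "05172b7f",
   "metadata": {},
   "source": [
    "We will use Tavily Search as the tool. You can get an API key [here](https://www.tavily.com/), or substitute a different tool of your choice.\n",
    "\n",
    "## 2. Set up LangSmith for LangGraph development\n",
    "\n",
    "Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor LLM applications built with LangGraph. Read more about how to get started [here](https://smith.langchain.com/).\n",
    "\n",
    "## 3. Define our LLM"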
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30b1468d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_anthropic import ChatAnthropic\n",
    "\n",
    "# Matches the langchain_anthropic install and the ANTHROPIC_API_KEY set during setup\n",
    "llm = ChatAnthropic(model=\"claude-3-sonnet-20240229\")\n",
    "\n",
    "# You can also use another provider (requires `langchain_openai` and OPENAI_API_KEY)\n",
    "# from langchain_openai import ChatOpenAI\n",
    "\n",
    "# llm = ChatOpenAI(model=\"gpt-4-turbo-preview\")"
   ]
  },
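  {
   "cell_type": "markdown",
   "id": "b7bfd532",
   "metadata": {},
   "source": [
    "## 4. Actor (with reflection)\n",
    "\n",
    "The main component of Reflexion is the \"actor\", an agent that critiques its own response and then re-executes based on that critique in order to improve. Its main subcomponents are:\n",
    "\n",
    "- Tools/tool execution\n",
    "- Initial responder: generates an initial response to the task (along with a self-reflection)\n",
    "- Revisor: re-responds (and reflects) based on the previous reflections\n",
    "\n",
    "We start by defining the tool execution context.\n",
    "\n",
    "--------------------\n",
    "\n",
    "## 5. Construct tools"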
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77519b95",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.tools.tavily_search import TavilySearchResults\n",
    "from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper\n",
    "\n",
    "search = TavilySearchAPIWrapper()\n",
    "tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "934ec558",
   "metadata": {},
   "source": [
    "## 6. Initial responder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92930f29",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.messages import HumanMessage, ToolMessage\n",
    "from langchain_core.output_parsers.openai_tools import PydanticToolsParser\n",
    "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
    "# Note: you must use langchain-core >= 0.3 with Pydantic v2\n",
    "from pydantic import BaseModel, Field, ValidationError\n",
    "\n",
    "\n",
    "class Reflection(BaseModel):\n",
    "    missing: str = Field(description=\"Critique of what is missing.\")\n",
    "    superfluous: str = Field(description=\"Critique of what is superfluous\")\n",
    "\n",
    "\n",
    "class AnswerQuestion(BaseModel):\n",
    "    answer: str = Field(description=\"~250 word detailed answer to the question.\")\n",
    "    reflection: Reflection = Field(description=\"Your reflection on the initial answer.\")\n",
    "    search_queries: list[str] = Field(\n",
    "        description=\"1-3 search queries for researching improvements to address the critique of your current answer.\"\n",
    "    )\n",
    "\n",
    "\n",
    "class ResponderWithRetries:\n",
    "    def __init__(self, runnable, validator):\n",
    "        self.runnable = runnable\n",
    "        self.validator = validator\n",
    "\n",
    "    def respond(self, state: dict):\n",
    "        # Graph nodes receive the state dict, so read the message list from it\n",
    "        messages = state[\"messages\"]\n",
    "        response = []\n",
    "        for attempt in range(3):\n",
    "            response = self.runnable.invoke(\n",
    "                {\"messages\": messages}, {\"tags\": [f\"attempt:{attempt}\"]}\n",
    "            )\n",
    "            try:\n",
    "                self.validator.invoke(response)\n",
    "                return {\"messages\": response}\n",
    "            except ValidationError as e:\n",
    "                # Feed the validation error back so the next attempt can fix it\n",
    "                messages = messages + [\n",
    "                    response,\n",
    "                    ToolMessage(\n",
    "                        content=f\"{repr(e)}\\n\\nPay close attention to the function schema.\\n\\n\"\n",
    "                        + self.validator.schema_json()\n",
    "                        + \" Respond by fixing all validation errors.\",\n",
    "                        tool_call_id=response.tool_calls[0][\"id\"],\n",
    "                    ),\n",
    "                ]\n",
    "        # All attempts failed validation; return the last response anyway\n",
    "        return {\"messages\": response}"
   ]
  },
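  {
   "cell_type": "markdown",
   "id": "reflexion-retry-sketch",
   "metadata": {},
   "source": [
    "The retry logic in `ResponderWithRetries` follows a general pattern: call the model, validate the output, and on failure append the error to the conversation before trying again. Here is a minimal sketch of that pattern using only the standard library. `fake_model` and `validate` are hypothetical stand-ins for the LLM runnable and the Pydantic parser, not LangChain APIs.\n",
    "\n",
    "```python\n",
    "def validate(response):\n",
    "    # Hypothetical validator: the response must carry a non-empty answer\n",
    "    if not response.get('answer'):\n",
    "        raise ValueError('answer is required')\n",
    "\n",
    "\n",
    "def fake_model(messages):\n",
    "    # Hypothetical model: only answers correctly after seeing an error message\n",
    "    saw_error = any(m.startswith('error:') for m in messages)\n",
    "    return {'answer': 'fixed' if saw_error else ''}\n",
    "\n",
    "\n",
    "def respond_with_retries(messages, max_attempts=3):\n",
    "    response = None\n",
    "    for _ in range(max_attempts):\n",
    "        response = fake_model(messages)\n",
    "        try:\n",
    "            validate(response)\n",
    "            return response\n",
    "        except ValueError as e:\n",
    "            # Feed the error back so the next attempt can correct itself\n",
    "            messages = messages + [f'error: {e}']\n",
    "    # Every attempt failed validation; return the last response as-is\n",
    "    return response\n",
    "\n",
    "\n",
    "print(respond_with_retries(['question']))\n",
    "```\n",
    "\n",
    "As in the class above, the last response is returned even when every attempt fails validation, so callers may want to check it before relying on it."
   ]
  },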
  {
   "cell_type": "markdown",
   "id": "c3dda96f",
   "metadata": {},
   "source": [
    "## 7. Revision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "411322d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "revise_instructions = '''Revise your previous answer using the new information.\n",
    "    - You should use the previous critique to add important information to your answer.\n",
    "        - You MUST include numerical citations in your revised answer to ensure it can be verified.\n",
    "        - Add a \"References\" section to the bottom of your answer (which does not count towards the word limit). In form of:\n",
    "            - [1] https://example.com\n",
    "            - [2] https://example.com\n",
    "    - You should use the previous critique to remove superfluous information from your answer and make SURE it is not more than 250 words.\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "021405f7",
   "metadata": {},
   "source": [
    "Extend the initial answer schema to include references. Forcing the model to cite its sources encourages grounded responses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8774c10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extend the initial answer schema to include references.\n",
    "# Forcing citations encourages the model to produce grounded responses.\n",
    "class ReviseAnswer(AnswerQuestion):\n",
    "    references: list[str] = Field(description=\"Citations motivating your updated answer.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78120cb8",
   "metadata": {},
   "source": [
    "## 8. Create the tool node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a884b113",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.tools import StructuredTool\n",
    "from langgraph.prebuilt import ToolNode\n",
    "\n",
    "\n",
    "def run_queries(search_queries: list[str], **kwargs):\n",
    "    \"\"\"Run the generated queries.\"\"\"\n",
    "    # The docstring doubles as the tool description that StructuredTool requires\n",
    "    return tavily_tool.batch([{\"query\": query} for query in search_queries])\n",
    "\n",
    "\n",
    "# Register the same search function under both schema names so that\n",
    "# tool calls made with either schema are routed to it\n",
    "tool_node = ToolNode(\n",
    "    [\n",
    "        StructuredTool.from_function(run_queries, name=AnswerQuestion.__name__),\n",
    "        StructuredTool.from_function(run_queries, name=ReviseAnswer.__name__),\n",
    "    ]\n",
    ")"
   ]
  },
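  {
   "cell_type": "markdown",
   "id": "reflexion-tool-routing-sketch",
   "metadata": {},
   "source": [
    "Both tools above wrap the same `run_queries` function but are registered under the schema names `AnswerQuestion` and `ReviseAnswer`. Whichever schema the model fills in, the tool call is routed to the search step, where only `search_queries` matters and the `**kwargs` in the signature tolerates the schema's other fields. A minimal standard-library sketch of that routing idea follows. `fake_search`, `TOOLS`, and `execute_tool_call` are hypothetical stand-ins for Tavily and `ToolNode`, not their real implementations.\n",
    "\n",
    "```python\n",
    "def fake_search(query):\n",
    "    # Hypothetical search backend returning one canned result per query\n",
    "    return f'results for {query}'\n",
    "\n",
    "\n",
    "def run_queries(search_queries, **kwargs):\n",
    "    # Extra schema fields (answer, reflection, ...) land in kwargs and are ignored\n",
    "    return [fake_search(q) for q in search_queries]\n",
    "\n",
    "\n",
    "# The same function is registered under both schema names\n",
    "TOOLS = {'AnswerQuestion': run_queries, 'ReviseAnswer': run_queries}\n",
    "\n",
    "\n",
    "def execute_tool_call(tool_call):\n",
    "    # Look up the tool by the name the model called and pass its arguments\n",
    "    return TOOLS[tool_call['name']](**tool_call['args'])\n",
    "\n",
    "\n",
    "call = {\n",
    "    'name': 'ReviseAnswer',\n",
    "    'args': {'answer': 'draft', 'references': [], 'search_queries': ['reflexion agents']},\n",
    "}\n",
    "print(execute_tool_call(call))\n",
    "```\n",
    "\n",
    "Reusing one function keeps the search behavior identical for drafts and revisions; only the name the model calls differs."
   ]
  },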
  {
   "cell_type": "markdown",
   "id": "423a43de",
   "metadata": {},
   "source": [
    "## 9. Construct the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "479405a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Annotated, Literal\n",
    "\n",
    "from langgraph.graph import END, START, StateGraph\n",
    "from langgraph.graph.message import add_messages\n",
    "from typing_extensions import TypedDict\n",
    "\n",
    "\n",
    "class State(TypedDict):\n",
    "    messages: Annotated[list, add_messages]\n",
    "\n",
    "\n",
    "MAX_ITERATIONS = 5\n",
    "# Assumes `first_responder` and `revisor` are ResponderWithRetries instances defined earlier\n",
    "builder = StateGraph(State)\n",
    "builder.add_node(\"draft\", first_responder.respond)\n",
    "builder.add_node(\"execute_tools\", tool_node)\n",
    "builder.add_node(\"revise\", revisor.respond)\n",
    "# draft -> execute_tools -> revise\n",
    "builder.add_edge(\"draft\", \"execute_tools\")\n",
    "builder.add_edge(\"execute_tools\", \"revise\")\n",
    "\n",
    "\n",
    "def _get_num_iterations(messages: list) -> int:\n",
    "    # Count the trailing run of tool and AI messages since the last other message\n",
    "    i = 0\n",
    "    for m in messages[::-1]:\n",
    "        if m.type not in {\"tool\", \"ai\"}:\n",
    "            break\n",
    "        i += 1\n",
    "    return i\n",
    "\n",
    "\n",
    "def event_loop(state: State) -> Literal[\"execute_tools\", \"__end__\"]:\n",
    "    # Stop once the trailing run exceeds MAX_ITERATIONS; otherwise search again\n",
    "    num_iterations = _get_num_iterations(state[\"messages\"])\n",
    "    if num_iterations > MAX_ITERATIONS:\n",
    "        return END\n",
    "    return \"execute_tools\"\n",
    "\n",
    "\n",
    "builder.add_conditional_edges(\"revise\", event_loop)\n",
    "builder.add_edge(START, \"draft\")\n",
    "graph = builder.compile()"
   ]
  },
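  {
   "cell_type": "markdown",
   "id": "reflexion-iteration-sketch",
   "metadata": {},
   "source": [
    "To see how the stopping rule behaves, here is a small self-contained sketch of the counting logic. `Msg` is a hypothetical stand-in for LangChain message objects (only the `type` attribute matters here), and `count_trailing_steps` mirrors `_get_num_iterations`: it counts the `ai` and `tool` messages that follow the most recent message of any other type.\n",
    "\n",
    "```python\n",
    "from dataclasses import dataclass\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Msg:\n",
    "    # Hypothetical stand-in for a LangChain message; only `type` is used\n",
    "    type: str\n",
    "\n",
    "\n",
    "def count_trailing_steps(messages):\n",
    "    # Walk backwards and stop at the first message that is not 'ai' or 'tool'\n",
    "    count = 0\n",
    "    for m in reversed(messages):\n",
    "        if m.type not in {'tool', 'ai'}:\n",
    "            break\n",
    "        count += 1\n",
    "    return count\n",
    "\n",
    "\n",
    "history = [Msg('human'), Msg('ai'), Msg('tool'), Msg('ai')]\n",
    "print(count_trailing_steps(history))\n",
    "```\n",
    "\n",
    "With `MAX_ITERATIONS = 5`, the graph stops once more than five `ai` and `tool` messages have accumulated since the user's question."
   ]
  },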
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bfc9f68",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "\n",
    "try:\n",
    "    display(Image(graph.get_graph().draw_mermaid_png()))\n",
    "except Exception:\n",
    "    # Rendering is optional and may need extra dependencies; skip if it fails\n",
    "    pass"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
